// Copyright © 2025 Bjango. All rights reserved.

#pragma once

#import <AudioToolbox/AudioToolbox.h>
#import <AVFoundation/AVFoundation.h>

#include <algorithm>
#include <cmath>
#include <vector>

#include "DSPKernel.hpp"
#include "BufferedAudioBus.hpp"

/// Bridges the AUv3 render block to the DSP kernel.
///
/// Each render cycle it pulls the main and side-chain input buses, reduces them to two
/// mono signals (index 0 = voice/main, index 1 = synth/side chain), splits the cycle at
/// the sample time of each realtime event, and calls `DSPKernel::process` per segment.
class AUProcessHelper
{
public:
    AUProcessHelper(DSPKernel& kernel, BufferedInputBus& mainBus, BufferedInputBus& sideBus)
    : mKernel{kernel}, mMainInputBus{mainBus}, mSideInputBus{sideBus}
    {
    }

    /// Records the negotiated channel layout and sizes the pointer and scratch storage.
    ///
    /// NOTE(review): presumably called when render resources are allocated (off the render
    /// thread), after the kernel's maximum frame count is known — confirm against the caller.
    /// @param mainChannelCount   Channel count of the main input bus.
    /// @param sideChannelCount   Channel count of the side-chain input bus.
    /// @param outputChannelCount Channel count of the output bus.
    void setChannelCounts(UInt32 mainChannelCount, UInt32 sideChannelCount, UInt32 outputChannelCount)
    {
        mMainChannelCount = mainChannelCount;
        mSideChannelCount = sideChannelCount;
        // The kernel input is always 2 mono channels: voice at index 0, synth at index 1.
        mInputBuffers.resize(kKernelInputChannelCount);
        mOutputBuffers.resize(outputChannelCount);

        // Pre-size the mono-mix scratch buffers so the render path normally never allocates.
        const UInt32 maxFrames = static_cast<UInt32>(mKernel.maximumFramesToRender());
        bufferForMonoMixMain(maxFrames);
        bufferForMonoMixSide(maxFrames);
    }

    /// Renders one cycle, splitting it into segments at each event's sample time so that
    /// events take effect at the correct frame.
    /// @param mainIn        Pulled main input bus.
    /// @param sideIn        Pulled side-chain input bus.
    /// @param outBufferList Host output buffers (one buffer per channel is assumed).
    /// @param timestamp     Timestamp of the first frame of the cycle.
    /// @param frameCount    Number of frames in the cycle.
    /// @param events        Head of the realtime event list (may be null).
    void processWithEvents(AudioBufferList* mainIn, AudioBufferList* sideIn, AudioBufferList* outBufferList, AudioTimeStamp const *timestamp, AUAudioFrameCount frameCount, AURenderEvent const *events)
    {
        AUEventSampleTime now = AUEventSampleTime(timestamp->mSampleTime);
        AUAudioFrameCount framesRemaining = frameCount;
        AURenderEvent const *nextEvent = events;

        // Renders frames [frameOffset, frameOffset + framesToDo) of the current cycle.
        auto callProcess = [this, mainIn, sideIn, outBufferList, frameCount]
                             (AUEventSampleTime segmentTime, AUAudioFrameCount framesToDo, AUAudioFrameCount frameOffset)
        {
            // Only `framesToDo` frames starting at `frameOffset` belong to this segment. Reading
            // the full `frameCount` from a non-zero offset would run past the end of the host buffers.
            if (mMainChannelCount <= 2) {
                // 2 or fewer main channels: assume a configured side chain bus (Logic Pro, Ableton Live, etc).
                mInputBuffers[0] = getMonoInputFromBus(mainIn, mMainChannelCount, frameOffset, framesToDo, bufferForMonoMixMain(frameCount), 0);
                mInputBuffers[1] = getMonoInputFromBus(sideIn, mSideChannelCount, frameOffset, framesToDo, bufferForMonoMixSide(frameCount), 0);
            } else {
                // 3 or more main channels: assume a host like Reaper where the side chain is
                // simply the channels that follow the first two main channels.
                mInputBuffers[0] = getMonoInputFromBus(mainIn, 2, frameOffset, framesToDo, bufferForMonoMixMain(frameCount), 0);
                mInputBuffers[1] = getMonoInputFromBus(mainIn, mMainChannelCount - 2, frameOffset, framesToDo, bufferForMonoMixSide(frameCount), 2);
            }

            // Never write past mOutputBuffers if the host hands us more buffers than configured.
            const UInt32 outputCount = std::min<UInt32>(outBufferList->mNumberBuffers,
                                                        static_cast<UInt32>(mOutputBuffers.size()));
            for (UInt32 ch = 0; ch < outputCount; ++ch) {
                mOutputBuffers[ch] = static_cast<float*>(outBufferList->mBuffers[ch].mData) + frameOffset;
            }

            mKernel.process(mInputBuffers, mOutputBuffers, segmentTime, framesToDo);
        };

        while (framesRemaining > 0) {
            if (!nextEvent) {
                // No more events: render everything that is left in one segment.
                callProcess(now, framesRemaining, frameCount - framesRemaining);
                return;
            }

            // Late events (stamped before `now`) give a zero-length segment and are applied immediately.
            const AUEventSampleTime headTime = nextEvent->head.eventSampleTime;
            const AUEventSampleTime framesUntilEvent = std::max(AUEventSampleTime(0), headTime - now);
            // Clamp so an event stamped beyond this cycle cannot wrap the unsigned framesRemaining.
            const AUAudioFrameCount segFrames =
                AUAudioFrameCount(std::min(framesUntilEvent, AUEventSampleTime(framesRemaining)));

            if (segFrames > 0) {
                callProcess(now, segFrames, frameCount - framesRemaining);
                framesRemaining -= segFrames;
                now += AUEventSampleTime(segFrames);
            }

            nextEvent = performAllSimultaneousEvents(now, nextEvent);
        }
    }

    /// Hands `evt` and every following event stamped at or before `now` to the kernel.
    /// @param now Current sample time.
    /// @param evt First event to handle; must not be null.
    /// @return The first event stamped after `now`, or null when the list is exhausted.
    AURenderEvent const * performAllSimultaneousEvents(AUEventSampleTime now, AURenderEvent const *evt)
    {
        do {
            mKernel.handleOneEvent(now, evt);
            evt = evt->head.next;
        } while (evt && evt->head.eventSampleTime <= now);
        return evt;
    }

    /// Builds the AUv3 internal render block.
    ///
    /// The block captures `this`, so this helper must outlive every use of the returned block.
    AUInternalRenderBlock internalRenderBlock()
    {
        return ^AUAudioUnitStatus(AudioUnitRenderActionFlags   *actionFlags,
                                  const AudioTimeStamp         *timestamp,
                                  AUAudioFrameCount             frameCount,
                                  NSInteger                     outputBusNumber,
                                  AudioBufferList              *outputData,
                                  const AURenderEvent          *realtimeEvents,
                                  AURenderPullInputBlock pullInputBlock)
        {
            if (frameCount > mKernel.maximumFramesToRender())
                return kAudioUnitErr_TooManyFramesToProcess;

            AudioUnitRenderActionFlags pullFlags = 0;

            // Pull both input buses (bus 0 = main, bus 1 = side chain).
            AUAudioUnitStatus err = mMainInputBus.pullInput(&pullFlags, timestamp, frameCount, 0, pullInputBlock);
            if (err) return err;

            // NOTE(review): an error pulling the side chain aborts the whole render; confirm this is
            // intended for hosts that leave the side chain disconnected.
            err = mSideInputBus.pullInput(&pullFlags, timestamp, frameCount, 1, pullInputBlock);
            if (err) return err;

            AudioBufferList* mainList = mMainInputBus.mutableAudioBufferList;
            AudioBufferList* sideList = mSideInputBus.mutableAudioBufferList;

            // Output pointer validation.
            // NOTE(review): AUv3 hosts may pass null mData to ask the AU to supply output memory;
            // this rejects that case instead — confirm intended.
            if (!outputData || !outputData->mBuffers[0].mData) {
                return kAudioUnitErr_InvalidProperty;
            }

            processWithEvents(mainList, sideList, outputData, timestamp, frameCount, realtimeEvents);
            return noErr;
        };
    }

private:
    /// Number of mono inputs the kernel expects (voice + synth).
    static constexpr size_t kKernelInputChannelCount = 2;

    /// Grows `buffer` to at least `frameCount` samples and returns its storage.
    /// Growing allocates, so on the render thread this should normally be a no-op.
    static float* scratchBufferWithCapacity(std::vector<float>& buffer, UInt32 frameCount) {
        if (buffer.size() < frameCount) {
            buffer.resize(frameCount);
        }
        return buffer.data();
    }

    /// Scratch storage for the main-input mono mix.
    float* bufferForMonoMixMain(UInt32 frameCount) {
        return scratchBufferWithCapacity(monoMixBufferMain, frameCount);
    }

    /// Scratch storage for the side-chain mono mix.
    float* bufferForMonoMixSide(UInt32 frameCount) {
        return scratchBufferWithCapacity(monoMixBufferSide, frameCount);
    }

    /// Returns a mono view of up to two channels of `bufferList`, starting at `startFromIndex`.
    ///
    /// - One channel: returns a pointer straight into the bus buffer (no copy).
    /// - Two or more: mixes the first two into `scratchBuffer`. If the second channel is
    ///   silent (some hosts, e.g. Logic, deliver mono as left-only stereo), copies the first
    ///   channel unchanged instead of halving it.
    /// @param frameOffset   First frame to read.
    /// @param frameCount    Number of frames to read and, if mixing, write to `scratchBuffer`.
    /// @param scratchBuffer Must hold at least `frameCount` samples.
    /// @return The mono signal, or nullptr when no usable channel is available.
    const float* getMonoInputFromBus(const AudioBufferList* bufferList, UInt32 channelCount, UInt32 frameOffset, UInt32 frameCount, float* scratchBuffer, int startFromIndex) {
        if (!bufferList || startFromIndex < 0 || UInt32(startFromIndex) >= bufferList->mNumberBuffers) {
            return nullptr;
        }

        // Never index past the buffers actually present in the list.
        channelCount = std::min(channelCount, bufferList->mNumberBuffers - UInt32(startFromIndex));
        if (channelCount == 0) { return nullptr; } // not much we can do with no channels

        const float* leftBase = static_cast<const float*>(bufferList->mBuffers[startFromIndex].mData);
        if (!leftBase) { return nullptr; }
        const float* left = leftBase + frameOffset;

        const float* rightBase = (channelCount >= 2)
            ? static_cast<const float*>(bufferList->mBuffers[startFromIndex + 1].mData)
            : nullptr;
        if (!rightBase) {
            // Mono bus (or missing second buffer): use the first channel directly.
            return left;
        }
        const float* right = rightBase + frameOffset;

        // Some apps like Logic pass mono in as left-channel-only stereo; detect a silent right channel.
        const bool isFakeStereo = std::none_of(right, right + frameCount,
                                               [](float sample) { return fabsf(sample) > 1e-6f; });

        for (UInt32 i = 0; i < frameCount; ++i) {
            scratchBuffer[i] = isFakeStereo ? left[i] : 0.5f * (left[i] + right[i]);
        }
        return scratchBuffer;
    }

    DSPKernel& mKernel;
    BufferedInputBus& mMainInputBus;
    BufferedInputBus& mSideInputBus;
    // Sized up front so a render before setChannelCounts() cannot index an empty vector.
    std::vector<const float*> mInputBuffers = std::vector<const float*>(kKernelInputChannelCount, nullptr);
    std::vector<float*> mOutputBuffers;

    std::vector<float> monoMixBufferMain;
    std::vector<float> monoMixBufferSide;

    UInt32 mMainChannelCount = 0;
    UInt32 mSideChannelCount = 0;
};
